//
//  daes.c
//  AESTest
//
//  Created by zmjios on 2016/10/26.
//  Copyright © 2016年 zmjios. All rights reserved.
//

#include "daes.h"

/*
 * Cipher dimensions:
 *   Nb - number of columns of the 4 x Nb state,
 *   Nr - number of rounds,
 *   Nk - key length in 4-digit words.
 * They are enumeration constants (integer constant expressions) rather than
 * mutable statics, so arrays sized from them, such as `state[4*Nb]` in
 * cipher()/inv_cipher(), are ordinary fixed-size arrays instead of
 * variable-length arrays.
 */
enum { Nb = 4, Nr = 10, Nk = 4 };

/*
 * Digit substitution table: a permutation of 0-9, indexed by a digit.
 * Used by sub_bytes(), control_round_key() and key_expansion().
 * The table is never written, so it is const.
 */
static const uint32_t s_box[10] =
{
    // 0    1     2     3     4     5     6     7     8     9
       6,   8,    4,    9,    7,    2,    1,    0,    5,    3
};

/*
 * Inverse of s_box: inv_s_box[s_box[d]] == d for every digit d.
 * Used by inv_sub_bytes().
 */
static const uint32_t inv_s_box[10] =
{
    // 0    1     2     3     4     5     6     7     8     9
       7,   6,    5,    9,    2,    8,    0,    4,    1,    3
};


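// Alternative S-box / inverse S-box pair (each table is the inverse of the
// other). Currently unused; kept for reference.
//static uint32_t s_box[10] =
//{
//    // 0    1     2     3     4     5     6     7     8     9
//       8,   5,    3,    7,    2,    4,    9,    6,    1,    0
//};
//
//static uint32_t inv_s_box[10] =
//{
//    // 0    1     2     3     4     5     6     7     8     9
//       9,   8,    4,    2,    5,    1,    7,    3,    0,    6
//};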

/**
 S-box substitution: replaces every cell of the state with its image
 under s_box.

 @param state state array, updated in place. Each cell is used as an index
              into the 10-entry s_box, so it must be a digit 0-9.
 @param len   number of cells in state
 */
void sub_bytes(uint32_t *state, uint32_t len) {
    uint32_t remaining = len;

    /* Each cell depends only on itself, so walking from the end is equivalent. */
    while (remaining > 0) {
        --remaining;
        state[remaining] = s_box[state[remaining]];
    }
}



/**
 Inverse S-box substitution: replaces every cell of the state with its image
 under inv_s_box, undoing sub_bytes().

 @param state state array, updated in place. Each cell is used as an index
              into the 10-entry inv_s_box, so it must be a digit 0-9.
 @param len   number of cells in state
 */
void inv_sub_bytes(uint32_t *state, uint32_t len) {
    for (uint32_t idx = 0; idx != len; ++idx) {
        const uint32_t digit = state[idx];
        state[idx] = inv_s_box[digit];
    }
}


/*
 * ShiftRows step: cyclically rotates row r of the state left by r cells
 * (row 0 is unchanged). The state is stored row-major as 4 rows of Nb cells
 * (cell (row, col) at index Nb*row + col). With Nb = 4, in terms of input
 * positions s0..s15:
 *   Row0: s0  s4  s8  s12   <<< 0
 *   Row1: s1  s5  s9  s13   <<< 1
 *   Row2: s2  s6  s10 s14   <<< 2
 *   Row3: s3  s7  s11 s15   <<< 3
 * Author's note: the number of columns should be at most 7
 * (NOTE(review): the reason is not evident from this code).
 */
void shift_rows(uint32_t *state) {
    for (uint32_t row = 1; row < 4; row++) {
        uint32_t *cells = state + Nb * row;

        /* Rotate left by one cell, `row` times. */
        for (uint32_t step = 0; step < row; step++) {
            const uint32_t first = cells[0];

            for (uint32_t col = 0; col + 1 < Nb; col++) {
                cells[col] = cells[col + 1];
            }
            cells[Nb - 1] = first;
        }
    }
}

/*
 * Inverse ShiftRows: cyclically rotates row r of the row-major 4 x Nb state
 * right by r cells (row 0 is unchanged), undoing shift_rows().
 */
void inv_shift_rows(uint32_t *state) {
    for (uint32_t row = 1; row < 4; row++) {
        uint32_t *cells = state + Nb * row;

        /* Rotate right by one cell, `row` times. */
        for (uint32_t step = 0; step < row; step++) {
            const uint32_t last = cells[Nb - 1];

            for (uint32_t col = Nb - 1; col > 0; col--) {
                cells[col] = cells[col - 1];
            }
            cells[0] = last;
        }
    }
}


/*
 * MixColumns step: multiplies every column of the row-major 4 x Nb state by
 * a fixed circulant matrix and reduces each result mod 10.
 * [s0]     [5 3 2 1]   [s0]
 * [s1]     [1 5 3 2] . [s1]
 * [s2]  =  [2 1 5 3]   [s2] mod 10
 * [s3]     [3 2 1 5]   [s3]
 * inv_mix_columns() multiplies by the inverse of this matrix mod 10.
 */
void mix_columns(uint32_t *state) {
    static const uint32_t coef[4][4] = {
        {5, 3, 2, 1},
        {1, 5, 3, 2},
        {2, 1, 5, 3},
        {3, 2, 1, 5},
    };
    uint32_t col, row, t;

    for (col = 0; col < Nb; col++) {
        uint32_t column[4], mixed[4];

        for (row = 0; row < 4; row++) {
            column[row] = state[Nb * row + col];
        }

        for (row = 0; row < 4; row++) {
            uint32_t acc = 0;
            for (t = 0; t < 4; t++) {
                acc += coef[row][t] * column[t];
            }
            mixed[row] = acc % 10;
        }

        for (row = 0; row < 4; row++) {
            state[Nb * row + col] = mixed[row];
        }
    }
}

/*
 * Inverse MixColumns: multiplies every column of the row-major 4 x Nb state
 * by the inverse (mod 10) of the mix_columns() matrix.
 * [s0]     [5 9 4 3]   [s0]
 * [s1]     [3 5 9 4] . [s1]
 * [s2]  =  [4 3 5 9]   [s2] mod 10
 * [s3]     [9 4 3 5]   [s3]
 */
void inv_mix_columns(uint32_t *state) {
    static const uint32_t coef[4][4] = {
        {5, 9, 4, 3},
        {3, 5, 9, 4},
        {4, 3, 5, 9},
        {9, 4, 3, 5},
    };
    uint32_t col, row, t;

    for (col = 0; col < Nb; col++) {
        uint32_t column[4], unmixed[4];

        for (row = 0; row < 4; row++) {
            column[row] = state[Nb * row + col];
        }

        for (row = 0; row < 4; row++) {
            uint32_t acc = 0;
            for (t = 0; t < 4; t++) {
                acc += coef[row][t] * column[t];
            }
            unmixed[row] = acc % 10;
        }

        for (row = 0; row < 4; row++) {
            state[Nb * row + col] = unmixed[row];
        }
    }
}


/**
 Round-key mixing step: combines the state with round key number `time`,
 taken from the expanded key schedule.

 The round key is the `len` digits k[len*time] .. k[len*time + len - 1].
 A selector `enter` (0-3) is derived from that round key. Cell i is then
 transformed according to (i + enter) % 4:
   0: state + k[.]          mod 10
   1: state - k[.]          mod 10
   2: state + s_box[k[.]]   mod 10
   3: state - s_box[k[.]]   mod 10
 inv_control_round_key() applies the opposite operation in each case.

 @param state state array of `len` cells, updated in place (expected digits 0-9)
 @param k     expanded key schedule; must hold at least len*(time+1) digits,
              and digits used in cases 2/3 must be 0-9 (s_box index)
 @param time  round index
 @param len   number of cells in the state array (0 is a no-op)
 */
void control_round_key(uint32_t *state, uint32_t *k, uint32_t time, uint32_t len) {

    uint32_t i0 = len * time;
    uint32_t enter = 0;
    uint32_t i;

    /* Selector derived from this round's key.
       NOTE(review): only the first len-1 digits are summed; the last one is
       excluded. inv_control_round_key() uses the same range, so the two stay
       in sync, and changing it would alter every ciphertext.
       The bound is written as `i + 1 < i0 + len` so that len == 0 cannot wrap
       it to UINT32_MAX and run past the end of k. */
    for (i = i0; i + 1 < i0 + len; i++) {
        enter = (enter + k[i]) % 4;
    }

    for (i = 0; i < len; i++) {
        const uint32_t kd = k[i0 + i];

        switch ((i + enter) % 4) {
        case 0:
            state[i] = (state[i] + kd) % 10;
            break;
        case 1:
            /* Modular subtraction done entirely in unsigned arithmetic.
               Converting a wrapped unsigned difference to int would be
               implementation-defined. */
            state[i] = (state[i] + 10 - kd % 10) % 10;
            break;
        case 2:
            state[i] = (state[i] + s_box[kd]) % 10;
            break;
        default: /* 3; s_box values are already digits 0-9 */
            state[i] = (state[i] + 10 - s_box[kd]) % 10;
            break;
        }
    }
}



/**
 Inverse of control_round_key(): undoes the round-key mixing for round `time`.

 The selector `enter` is derived from the round key exactly as in
 control_round_key(). Each case then applies the opposite operation,
 selected by (i + enter) % 4:
   0: state - k[.]          mod 10
   1: state + k[.]          mod 10
   2: state - s_box[k[.]]   mod 10
   3: state + s_box[k[.]]   mod 10

 @param state state array of `len` cells, updated in place (expected digits 0-9)
 @param k     expanded key schedule; must hold at least len*(time+1) digits,
              and digits used in cases 2/3 must be 0-9 (s_box index)
 @param time  round index
 @param len   number of cells in the state array (0 is a no-op)
 */
void inv_control_round_key(uint32_t *state, uint32_t *k, uint32_t time, uint32_t len)
{
    uint32_t i0 = len * time;
    uint32_t enter = 0;
    uint32_t i;

    /* Must match control_round_key(): only the first len-1 digits of the round
       key are summed. The bound is written as `i + 1 < i0 + len` so that
       len == 0 cannot wrap it to UINT32_MAX and run past the end of k. */
    for (i = i0; i + 1 < i0 + len; i++) {
        enter = (enter + k[i]) % 4;
    }

    for (i = 0; i < len; i++) {
        const uint32_t kd = k[i0 + i];

        switch ((i + enter) % 4) {
        case 0:
            /* Modular subtraction done entirely in unsigned arithmetic.
               Converting a wrapped unsigned difference to int would be
               implementation-defined. */
            state[i] = (state[i] + 10 - kd % 10) % 10;
            break;
        case 1:
            state[i] = (state[i] + kd) % 10;
            break;
        case 2: /* s_box values are already digits 0-9 */
            state[i] = (state[i] + 10 - s_box[kd]) % 10;
            break;
        default: /* 3 */
            state[i] = (state[i] + s_box[kd]) % 10;
            break;
        }
    }
}


/*
 * Key-expansion helper: rotates a four-element word left by one position,
 * so {a, b, c, d} becomes {b, c, d, a}. Operates in place.
 */
void rot_word(uint32_t *w) {
    const uint32_t head = w[0];

    w[0] = w[1];
    w[1] = w[2];
    w[2] = w[3];
    w[3] = head;
}

/* Backing storage for the word returned by Rcon(); Rcon() writes only element 0. */
uint32_t R[4] = {0, 0, 0, 0};

/*
 * Returns the round-constant word for key-expansion step i: {i + 1, R[1], R[2], R[3]}.
 * NOTE: the returned pointer is the shared global R, which is overwritten on
 * every call. It is therefore only valid until the next Rcon() call, and the
 * function is not reentrant.
 */
uint32_t * Rcon(uint32_t i) {
    uint32_t *word = &R[0];

    *word = i + 1;
    return word;
}

/*
 * Expands the cipher key into the round-key schedule w.
 *
 * key supplies 4*Nk digits. w receives Nb*(Nr+1) four-digit words
 * (4*Nb*(Nr+1) digits in total), and its first Nk words are a copy of key.
 * Every later word is the digit-wise sum (mod 10) of the word Nk positions
 * back and the previous word. For every Nk-th word, the previous word is
 * first rotated with rot_word(), offset digit-wise by Rcon(word/Nk), and
 * passed through s_box.
 */
void key_expansion(uint32_t *key, uint32_t *w) {
    const uint32_t total_words = Nb * (Nr + 1);
    uint32_t word, d;

    /* The first Nk words are the raw key. */
    for (d = 0; d < 4u * Nk; d++) {
        w[d] = key[d];
    }

    for (word = Nk; word < total_words; word++) {
        uint32_t prev[4];

        for (d = 0; d < 4; d++) {
            prev[d] = w[4 * (word - 1) + d];
        }

        if (word % Nk == 0) {
            rot_word(prev);
            /* Rcon() returns the same word for every digit, so fetch it once. */
            const uint32_t *rc = Rcon(word / Nk);
            for (d = 0; d < 4; d++) {
                prev[d] = s_box[(prev[d] + rc[d]) % 10];
            }
        }

        for (d = 0; d < 4; d++) {
            w[4 * word + d] = (w[4 * (word - Nk) + d] + prev[d]) % 10;
        }
    }
}


void cipher(uint32_t *in, uint32_t *out, uint32_t *w) {
    
    uint32_t state[4*Nb];
    uint32_t r, i, j;
    
    for (i = 0; i < 4; i++) {
        for (j = 0; j < Nb; j++) {
            state[Nb*i+j] = in[i+4*j];
        }
    }
    
    uint32_t len = (uint32_t)sizeof(state) / sizeof(state[0]);
    
    control_round_key(state, w,0,len);
    
    for (r = 1; r < Nr; r++) {
        
        sub_bytes(state,len);
        shift_rows(state);
        mix_columns(state);
        control_round_key(state, w, r,len);
    }
    
    sub_bytes(state,len);
    shift_rows(state);
    control_round_key(state, w, Nr,len);
    
    for (i = 0; i < 4; i++) {
        for (j = 0; j < Nb; j++) {
            out[i+4*j] = state[Nb*i+j];
        }
    }
}

void inv_cipher(uint32_t *in, uint32_t *out, uint32_t *w) {
    
    uint32_t state[4*Nb];
    uint32_t r, i, j;
    
    for (i = 0; i < 4; i++) {
        for (j = 0; j < Nb; j++) {
            state[Nb*i+j] = in[i+4*j];
        }
    }
    
    uint32_t len = (uint32_t)sizeof(state) / sizeof(state[0]);
    
    inv_control_round_key(state, w, Nr, len);
    
    for (r = Nr-1; r >= 1; r--) {
        
        inv_shift_rows(state);
        inv_sub_bytes(state,len);
        inv_control_round_key(state, w, r,len);
        inv_mix_columns(state);
    }
    
    inv_shift_rows(state);
    inv_sub_bytes(state,len);
    inv_control_round_key(state, w, 0,len);
    
    for (i = 0; i < 4; i++) {
        for (j = 0; j < Nb; j++) {
            out[i+4*j] = state[Nb*i+j];
        }
    }
}




